#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <tuple>
#include "cell.h"
#include "comm.h"
#include "elemtab.h"
/**
 * Round `off` up to the nearest multiple of `alignment`.
 *
 * `alignment` is a compile-time constant and must be a non-zero power of
 * two; any other value is rejected at compile time.  An offset that is
 * already a multiple of `alignment` is returned unchanged.
 */
template<size_t alignment>
constexpr size_t align_to(size_t off){
  // x & (x - 1) clears the lowest set bit, so it is zero only for powers of
  // two (and for zero, which is excluded explicitly).
  static_assert(alignment != 0 && (alignment & (alignment - 1)) == 0,
                "alignment must be a non-zero power of two");
  return (off + alignment - 1) & ~(alignment - 1);
}
/**
 * Copy `n` objects of type T from `src` into the byte buffer `target`.
 *
 * The copy starts at the first offset at or after `off` that is a multiple
 * of alignof(T).  The return value is the offset one past the last byte
 * written, i.e. the `off` to hand to the next call.  No bounds checking is
 * done on `target`; the caller has to provide enough room.
 */
template<typename T>
size_t aligned_copyn(char *target, size_t off, T *src, int n) {
  const size_t start = align_to<alignof(T)>(off);
  const size_t nbytes = sizeof(T) * n;
  memcpy(target + start, src, nbytes);
  return start + nbytes;
}

/* Granularity (bytes) to which FileOutBuf::flush pads its final write. */
constexpr size_t ALIGN = static_cast<size_t>(4) << 10;  /* 4 KiB */
/* Size (bytes) of the chunks FileOutBuf writes and FileInBuf reads. */
constexpr size_t BSIZE = static_cast<size_t>(4) << 20;  /* 4 MiB */
/**
 * Staging buffer that serializes typed arrays into a FILE in BSIZE chunks.
 *
 * Each append() places its data at the next offset aligned for the element
 * type.  Whenever the fill level exceeds BSIZE, one full BSIZE chunk is
 * written out and the remainder is kept at the front of the buffer.  The
 * buffer carries sizeof(celldata_t) bytes of slack past BSIZE so that a
 * single append may overshoot BSIZE before the chunk is written; callers
 * must not append more than that in one call.
 *
 * The object does not own `f`; the caller opens and closes it.
 */
struct FileOutBuf {
  char buf[BSIZE + sizeof(celldata_t)];
  size_t off;  /* number of pending bytes in buf */
  FILE *f;
  FileOutBuf(FILE *f) : off(0), f(f) {}

  /** Append `n` objects of type T starting at `src`, aligned to alignof(T). */
  template<typename T>
  void append(T *src, int n) {
    off = aligned_copyn(buf, off, src, n);
    while (off > BSIZE) {
      fwrite_unlocked(buf, 1, BSIZE, f);
      // Keep the bytes that spilled past the chunk boundary.  BSIZE is a
      // multiple of every alignment used here, so subtracting it preserves
      // the alignment of subsequent appends relative to the file offset.
      memmove(buf, buf + BSIZE, off - BSIZE);
      off -= BSIZE;
    }
  }

  /**
   * Write out all pending bytes, zero-padded up to a multiple of ALIGN, flush
   * the stream and leave the buffer empty.
   */
  void flush() {
    const size_t padded = align_to<ALIGN>(off);
    memset(buf + off, 0, padded - off);
    fwrite_unlocked(buf, 1, padded, f);
    fflush_unlocked(f);
    off = 0;
  }
};
/**
 * Read-side counterpart of FileOutBuf: pulls a FILE through a BSIZE buffer
 * and hands out typed arrays at element-aligned offsets.
 *
 * The first chunk is read in the constructor.  The object does not own `f`.
 */
struct FileInBuf {
  char buf[BSIZE];
  size_t off;  /* read position inside buf */
  FILE *f;
  FileInBuf(FILE *f) : off(0), f(f) {
    refill();
  }

  /**
   * Load the next chunk from the file and rewind the read position.
   * If fewer than BSIZE bytes are available (end of file or read error) the
   * remainder of the buffer is zeroed, so reads past the end of the data
   * yield zeros rather than leftovers from the previous chunk.
   */
  void refill() {
    const size_t got = fread_unlocked(buf, 1, BSIZE, f);
    if (got < BSIZE)
      memset(buf + got, 0, BSIZE - got);
    off = 0;
  }

  /** Copy the next `size` raw bytes into `dst`, refilling as chunks run out. */
  void extract_char(char *dst, size_t size) {
    size_t rest = size;
    char *ptr = dst;
    while (rest > 0) {
      const size_t avail = BSIZE - off;
      const size_t ncopy = avail < rest ? avail : rest;
      memcpy(ptr, buf + off, ncopy);
      rest -= ncopy;
      ptr += ncopy;
      off += ncopy;
      if (off == BSIZE)
        refill();
    }
  }

  /** Read `n` objects of type T into `dst` from the next alignof(T) offset. */
  template<typename T>
  void extract(T *dst, int n) {
    off = align_to<alignof(T)>(off);
    extract_char((char*)dst, sizeof(T) * n);
  }
};
/**
 * Serialize one cell for trajectory output.
 *
 * Appends, in order: the atom count (as a long), the cell's `basis`, and the
 * per-atom arrays `tag`, `t` and `x`.  When NEED_V is non-zero the `v` array
 * is appended as well.
 */
template<int NEED_V>
void pack_cell_traj(FileOutBuf &obuf, celldata_t *cell){
  long count = cell->natom;
  obuf.append(&count, 1);
  obuf.append(&cell->basis, 1);

  obuf.append(cell->tag, cell->natom);
  obuf.append(cell->t, cell->natom);
  obuf.append(cell->x, cell->natom);

  if (!NEED_V)
    return;
  obuf.append(cell->v, cell->natom);
}
/**
 * Serialize one cell with every per-atom array handled in this file.
 *
 * Appends the atom count (as a long) followed by `q`, `x`, `tag`, `v`,
 * `mass`, `rmass`, `t` and `rigid`, each holding `natom` entries.
 */
void pack_cell_full_restart(FileOutBuf &obuf, celldata_t *cell) {
  long count = cell->natom;
  // Every per-atom array is written with the same element count.
  auto put = [&obuf, count](auto *field) { obuf.append(field, count); };

  obuf.append(&count, 1);
  put(cell->q);
  put(cell->x);
  put(cell->tag);
  put(cell->v);
  put(cell->mass);
  put(cell->rmass);
  put(cell->t);
  put(cell->rigid);
}

/**
 * Serialize one cell in the layout read back by unpack_cell_minimal_restart.
 *
 * Appends the atom count (as a long) followed by `q`, `x`, `tag`, `v`, `mass`
 * and `t`.  `rmass` is not stored; the unpacking side recomputes it from
 * `mass`.
 */
void pack_cell_minimal_restart(FileOutBuf &obuf, celldata_t *cell) {
  long count = cell->natom;
  const int nelem = static_cast<int>(count);  // append() takes an int count

  obuf.append(&count, 1);
  obuf.append(cell->q, nelem);
  obuf.append(cell->x, nelem);
  obuf.append(cell->tag, nelem);
  obuf.append(cell->v, nelem);
  obuf.append(cell->mass, nelem);
  obuf.append(cell->t, nelem);
}

/**
 * Restore one cell from the layout produced by pack_cell_minimal_restart.
 *
 * Reads the atom count into `cell->natom`, then fills `q`, `x`, `tag`, `v`,
 * `mass` and `t`, and finally rebuilds `rmass` as the reciprocal of `mass`.
 * The count read from the file is not checked against the capacity of the
 * cell's arrays.
 */
void unpack_cell_minimal_restart(FileInBuf &ibuf, celldata_t *cell) {
  long count;
  ibuf.extract(&count, 1);
  cell->natom = count;

  // All per-atom arrays share the element count just read.
  auto get = [&ibuf, count](auto *field) { ibuf.extract(field, count); };
  get(cell->q);
  get(cell->x);
  get(cell->tag);
  get(cell->v);
  get(cell->mass);
  get(cell->t);

  for (int atom = 0; atom < count; ++atom)
    cell->rmass[atom] = 1. / cell->mass[atom];
}
/**
 * Run the compile-time selected PACK routine on every cell visited by
 * FOREACH_LOCAL_CELL for `grid`, appending the result to `obuf`.
 */
template<void (*PACK)(FileOutBuf &obuf, celldata_t *cell)>
void write_data(FileOutBuf &obuf, cellgrid_t *grid){
  FOREACH_LOCAL_CELL(grid, ci, cj, ck, cur_cell) {
    PACK(obuf, cur_cell);
  }
}
/**
 * Run the compile-time selected UNPACK routine on every cell visited by
 * FOREACH_LOCAL_CELL for `grid`, consuming data from `ibuf`.
 */
template<void (*UNPACK)(FileInBuf &ibuf, celldata_t *cell)>
void read_data(FileInBuf &ibuf, cellgrid_t *grid){
  FOREACH_LOCAL_CELL(grid, ci, cj, ck, cur_cell) {
    UNPACK(ibuf, cur_cell);
  }
}
/* Value stored in grid_header::status. */
enum restart_status {
  RESTART_INCOMPLETE = 0,  /* written first, while the cell data is still being written */
  RESTART_MINIMAL = 1,     /* patched in by write_grid once the data has been flushed */
  RESTART_FULL = 2
};
/*
 * Fixed-size record at the start of a grid restart file.  write_grid copies
 * it to disk byte-for-byte and read_grid reads it back the same way, so the
 * member order and types define the on-disk layout.  Apart from `status`,
 * every member mirrors the cellgrid_t member of the same name.
 */
struct grid_header {
  long status;  /* a restart_status value */
  box<real> lbox;
  box<real> gbox;
  vec<int> nlocal;
  vec<int> nall;
  int nn;
  box<int> dim;
  vec<real> skin;
  vec<real> len;
  vec<real> rlen;
  real rcut;
};
/*
 * One element-table record as stored in the element-table file: written by
 * write_elemtab and read back by read_elemtab as a raw struct, so the layout
 * of this type is part of the file format.
 */
struct elem_entry {
  int itype;      /* value associated with the element name in the table */
  char stype[8];  /* element name; read_elemtab uses it as a C-string key */
};
#include <climits>
/**
 * Write the element table to `path`.
 *
 * The file holds the entry count (`elemtab->cnt`, as an int) followed by one
 * elem_entry per table element.  Names that do not fit into
 * elem_entry::stype (including its terminating NUL) are truncated and a
 * warning is printed to stderr.
 *
 * If the file cannot be opened the error is reported and the program is
 * aborted: the function has no way to signal failure to its caller, and
 * continuing would previously have passed a null FILE* to stdio.
 */
void write_elemtab(const char *path, elemtab_t *elemtab) {
  FILE *f = fopen(path, "wb");
  if (f == NULL) {
    perror(path);
    abort();
  }
  FileOutBuf obuf(f);
  int ntypes = elemtab->cnt;
  obuf.append(&ntypes, 1);
  for (auto &elem : *elemtab) {
    // Zero-initialize so that no indeterminate bytes end up in the file and
    // stype stays NUL-terminated after the bounded copy below.
    elem_entry e{};
    if (strlen(elem.key) >= sizeof(e.stype))
      fprintf(stderr, "%s: element name '%s' truncated to %d characters\n",
              path, elem.key, (int)(sizeof(e.stype) - 1));
    strncpy(e.stype, elem.key, sizeof(e.stype) - 1);
    e.itype = elem.val;
    obuf.append(&e, 1);
  }
  obuf.flush();
  fclose(f);
}
/**
 * Load an element table from the file at `path` (format: an int entry count
 * followed by that many elem_entry records) and store each name -> type
 * mapping through `(*elemtab)[name]`.
 *
 * If the file cannot be opened the error is reported and the program is
 * aborted: the function has no way to signal failure to its caller, and
 * continuing would previously have passed a null FILE* to stdio.
 */
void read_elemtab(const char *path, elemtab_t *elemtab) {
  FILE *f = fopen(path, "rb");
  if (f == NULL) {
    perror(path);
    abort();
  }
  FileInBuf ibuf(f);
  int ntypes;
  ibuf.extract(&ntypes, 1);
  for (int i = 0; i < ntypes; i ++) {
    elem_entry e;
    ibuf.extract(&e, 1);
    // The name is used as a C string below; never trust the file to have
    // terminated it.
    e.stype[sizeof(e.stype) - 1] = '\0';
    (*elemtab)[e.stype] = e.itype;
  }
  fclose(f);
}
/**
 * Write `grid` to the restart file `path`.
 *
 * Layout: a grid_header followed by every local cell packed with
 * pack_cell_minimal_restart.  The header is first written with status
 * RESTART_INCOMPLETE; only after all cell data has been flushed is the
 * header at the start of the file overwritten with status RESTART_MINIMAL,
 * so an interrupted write leaves a file that is marked incomplete.
 *
 * If the file cannot be opened the error is reported and the program is
 * aborted: the function has no way to signal failure, and its caller goes on
 * to record the restart as finished.
 */
void write_grid(const char *path, cellgrid_t *grid){
  grid_header hdr;
  hdr.status = RESTART_INCOMPLETE;
  hdr.lbox = grid->lbox;
  hdr.gbox = grid->gbox;
  hdr.nlocal = grid->nlocal;
  hdr.nall = grid->nall;
  hdr.nn = grid->nn;
  hdr.dim = grid->dim;
  hdr.skin = grid->skin;
  hdr.len = grid->len;
  hdr.rlen = grid->rlen;
  hdr.rcut = grid->rcut;

  FILE *rstf = fopen(path, "wb");
  if (rstf == NULL) {
    perror(path);
    abort();
  }
  FileOutBuf obuf(rstf);
  obuf.append(&hdr, 1);

  write_data<pack_cell_minimal_restart>(obuf, grid);
  obuf.flush();

  // Data is on its way to disk: go back and mark the file as usable.
  fseek(rstf, 0, SEEK_SET);
  hdr.status = RESTART_MINIMAL;
  fwrite_unlocked(&hdr, sizeof(grid_header), 1, rstf);
  fclose(rstf);
}
/**
 * Load `grid` from the restart file `path` written by write_grid.
 *
 * The header fields are copied into the grid members of the same name, then
 * every local cell is filled with unpack_cell_minimal_restart.
 *
 * The program is aborted with a message if the file cannot be opened or if
 * its header status is not RESTART_MINIMAL (i.e. the file was not completely
 * written).  The status check is performed unconditionally rather than via
 * assert(), so it is still enforced when NDEBUG is defined.
 */
void read_grid(const char *path, cellgrid_t *grid) {
  grid_header hdr;
  FILE *rstf = fopen(path, "rb");
  if (rstf == NULL) {
    perror(path);
    abort();
  }
  FileInBuf ibuf(rstf);
  ibuf.extract(&hdr, 1);
  if (hdr.status != RESTART_MINIMAL) {
    fprintf(stderr, "%s: restart file is incomplete (status %ld)\n",
            path, hdr.status);
    abort();
  }

  grid->lbox = hdr.lbox;
  grid->gbox = hdr.gbox;
  grid->nlocal = hdr.nlocal;
  grid->nall = hdr.nall;
  grid->nn = hdr.nn;
  grid->dim = hdr.dim;
  grid->skin = hdr.skin;
  grid->len = hdr.len;
  grid->rlen = hdr.rlen;
  grid->rcut = hdr.rcut;

  read_data<unpack_cell_minimal_restart>(ibuf, grid);
  fclose(rstf);
}
/* Directory separator of the host platform, as a string literal. */
#if !defined(WIN32) && !defined(_WIN32)
#define PATH_SEP "/"
#else
#define PATH_SEP "\\"
#endif
void write_restart(const char *prefix, int itape, long istep, mpp_t *mpp, cellgrid_t *grid) {
  FILE *metadata;
  char meta_path[PATH_MAX];
  sprintf(meta_path, "%s-%d.meta", prefix, itape);
  FILE *meta_file;
  /* Truncate meta file first, to avoid reading unfinished restart */
  if (mpp->pid == 0) meta_file = fopen(meta_path, "w");
  MPI_Barrier(mpp->comm);
  char grid_file[PATH_MAX];
  sprintf(grid_file, "%s-%d-%06d.esgrid", prefix, itape, mpp->pid);
  write_grid(grid_file, grid);
  MPI_Barrier(mpp->comm);
  /* Finish the metadata file*/
  fprintf(meta_file, "%d\n", istep);
  fclose(meta_file);
}

/**
 * Read the step number recorded in "<prefix>-<itape>.meta".
 *
 * Returns the stored step, -1 if the meta file cannot be opened, or -2 if it
 * exists but does not contain a number (e.g. it was truncated and the restart
 * was never finished).
 */
long check_step(const char *prefix, int itape) {
  char meta_path[PATH_MAX];
  snprintf(meta_path, sizeof(meta_path), "%s-%d.meta", prefix, itape);
  FILE *meta_file = fopen(meta_path, "r");
  if (meta_file == NULL)
    return -1;
  long cur_step;
  if (fscanf(meta_file, "%ld", &cur_step) != 1) cur_step = -2;
  fclose(meta_file);
  return cur_step;
}
/**
 * Collectively choose which restart tape (0 or 1) to load.
 *
 * Rank 0 inspects both meta files with check_step and picks tape 0 if its
 * step is strictly greater than tape 1's, otherwise tape 1.  If neither tape
 * has a valid (non-negative) step, the choice is (-1, -1).  The decision is
 * then broadcast, so every rank returns the same (tape, step) pair.
 *
 * All ranks of `mpp->comm` must call this; every rank, including rank 0 in
 * the "no restart available" case, takes part in the broadcast.
 */
std::tuple<int, long> select_tape(const char *prefix, mpp_t *mpp) {
  /* buf[0] = tape index, buf[1] = step; defaults mean "no restart available" */
  long buf[2] = {-1L, -1L};
  if (mpp->pid == 0) {
    const long tape0_step = check_step(prefix, 0);
    const long tape1_step = check_step(prefix, 1);
    if (tape0_step >= 0 || tape1_step >= 0) {
      if (tape0_step > tape1_step) {
        buf[0] = 0;
        buf[1] = tape0_step;
      } else {
        buf[0] = 1;
        buf[1] = tape1_step;
      }
    }
  }
  MPI_Bcast(buf, 2, MPI_LONG, 0, mpp->comm);
  return std::make_tuple(static_cast<int>(buf[0]), buf[1]);
}
/**
 * Collectively restore the simulation state from the newest complete tape.
 *
 * Uses select_tape to pick the tape, loads the element table from
 * "<prefix>.esatab" and this rank's grid from
 * "<prefix>-<tape>-<pid>.esgrid" (the grid file name is also printed to
 * stdout).
 *
 * Returns the step number of the restored state, or -1 if select_tape found
 * no usable tape, in which case nothing is read.
 */
long read_restart(const char *prefix, mpp_t *mpp, cellgrid_t *grid, elemtab_t *elemtab) {
  int itape;
  long istep;
  std::tie(itape, istep) = select_tape(prefix, mpp);
  if (itape == -1) return -1L;

  // Bounded formatting: `prefix` is caller-supplied and must not be able to
  // overrun the fixed-size path buffers.
  char elemtab_path[PATH_MAX];
  snprintf(elemtab_path, sizeof(elemtab_path), "%s.esatab", prefix);
  read_elemtab(elemtab_path, elemtab);

  char grid_file[PATH_MAX];
  snprintf(grid_file, sizeof(grid_file), "%s-%d-%06d.esgrid", prefix, itape, mpp->pid);
  puts(grid_file);
  read_grid(grid_file, grid);
  return istep;
}

/**
 * Append one trajectory frame for step `istep` to the file at `path`.
 *
 * A frame consists of the step number, the number of local cells
 * (`grid->nlocal.vol()`), and every local cell packed with
 * pack_cell_traj<1> (i.e. including velocities).  The file is opened in
 * append mode, so successive calls add frames to the same file.
 *
 * If the file cannot be opened the error is reported and the program is
 * aborted; previously a null FILE* would have been passed on to stdio.
 */
void write_trajectory(const char *path, long istep, cellgrid_t *grid){
  FILE *f = fopen(path, "ab");
  if (f == NULL) {
    perror(path);
    abort();
  }
  FileOutBuf obuf(f);
  long ncell = grid->nlocal.vol();
  obuf.append(&istep, 1);
  obuf.append(&ncell, 1);
  write_data<pack_cell_traj<1>>(obuf, grid);
  obuf.flush();
  fclose(f);
}
